// [[Rcpp::plugins("cpp11")]]

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <RcppArmadilloExtensions/rmultinom.h>
using namespace Rcpp;
using namespace arma;

// [[Rcpp::export]]
// function to draw multivariate-normal random variables
NumericVector mvrnorm(int n, colvec mu, mat sigma) {
  // Draws n samples from the multivariate normal N(mu, sigma) and returns them
  // as an n x d matrix, one draw per row. Each row is mu' + z' R where
  // z ~ N(0, I_d) and R = chol(sigma) is upper-triangular with R'R = sigma.
  //   n      number of draws (must be >= 0)
  //   mu     mean vector of length d
  //   sigma  d x d covariance matrix; chol() fails unless it is positive definite
  if (n < 0) {
    Rcpp::stop("n must be non-negative");
  }
  // The previous version aliased mu's memory with sigma's dimension, reading
  // past the end of mu whenever the two sizes disagreed.
  if (!sigma.is_square() || mu.n_elem != sigma.n_rows) {
    Rcpp::stop("sigma must be a square matrix matching the length of mu");
  }
  const uword d = sigma.n_cols;
  mat result = mat(n, d, fill::randn) * chol(sigma);  // rows ~ N(0, sigma)
  result.each_row() += mu.t();                         // shift every row by mu'
  return wrap(result);
}
// Vectorized normal density: res[h] = dnorm(x[h], means[h], sds[h]).
// `means` and `sds` may each have length 1 (recycled for every element of x,
// as R's dnorm() does) or the same length as x; any other length is rejected
// instead of reading past the end of the shorter vector.
NumericVector my_dnorm(NumericVector x, NumericVector means, NumericVector sds){
  const R_xlen_t nn = x.size();
  const R_xlen_t n_mean = means.size();
  const R_xlen_t n_sd = sds.size();
  if ((n_mean != 1 && n_mean != nn) || (n_sd != 1 && n_sd != nn)) {
    Rcpp::stop("means and sds must have length 1 or length(x)");
  }
  NumericVector res(nn);
  for (R_xlen_t h = 0; h < nn; h++) {
    const double mu = means[n_mean == 1 ? 0 : h];
    const double sd = sds[n_sd == 1 ? 0 : h];
    res[h] = R::dnorm(x[h], mu, sd, false);
  }
  return res;
}
// Draws one probability vector from a Dirichlet distribution with
// concentration parameters w: one Gamma(w[j], scale = 1) variate per category
// (drawn in category order), then normalized by their total.
// Returns a length(w) x 1 numeric matrix.
NumericVector rdirichlet(NumericVector w){
  colvec draws(w.size());
  uword pos = 0;
  for (const double alpha : w) {
    draws[pos++] = as_scalar(randg(1, distr_param(alpha, 1.0)));
  }
  const double total = sum(draws);
  draws /= total;
  return wrap(draws);
}

// [[Rcpp::export]]
List FMR(colvec y_r,  // rating-based outcomes
         colvec y_c,  // choice-based outcomes
         mat x,  // independent variables
         colvec id,  // respondent ids
         int l,  // number of respondents
         int g,  // number of latent groups
         colvec b0_r,  // prior means for rating coefficients
         mat B0_r,  // prior variance matrix for rating coefficients
         colvec b0_c,  // prior means for choice coefficients
         mat B0_c,  // prior variance matrix for choice coefficients
         double v0_r,  // prior degrees of freedom for the error variance of rating
         double s0_r,  // prior scale for the error variance of rating
         double v0_c,  // prior degrees of freedom for the error variance of choice
         double s0_c,  // prior scale for the error variance of choice
         colvec w0,  // prior concentration parameter
         int burnin,  // number of burn-in iterations
         int iter,  // number of sampling iterations
         int thin  // thinning interval
) {
  int i, j, k, nmember;
  int n = y_r.n_rows; // total number of observations
  int m = x.n_cols + 1; // number of attribute-levels + an intercept
  colvec onevec = colvec(n, fill::ones);
  mat X = join_rows(onevec, x); // design matrix
  colvec initbeta = colvec(m, fill::randn) * 0.1; // initial values for coefficients
  mat beta_r = mat(m, g); // coefficients for each group (rating)
  for (k = 0; k < g; k ++) {
    beta_r.col(k) = initbeta;
  }
  mat beta_c = mat(m, g); // coefficients for each group (choice)
  for (k = 0; k < g; k ++) {
    beta_c.col(k) = initbeta;
  }
  double initsigma = 1 + R::rnorm(0.0, 0.1); // initial value for error variance
  colvec sigma_r = colvec(g); // error variance (rating)
  colvec sigma_c = colvec(g); // error variance (choice)
  sigma_r.fill(initsigma);
  sigma_c.fill(initsigma);
  mat S(l, g); // matrix for the group assignment of each individual
  for (k = 0; k < l; k ++) {
    S.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(colvec(g, fill::ones) / g))));
  }
  mat S2(n, g); // matrix for the group assignment of each observation 
  for (k = 0; k < n; k ++) {
    S2.row(k) = S.row(id(k) - 1);
  }
  colvec muvec = colvec(n);
  colvec sdvec = colvec(n);
  mat lk_r = mat(n, g); // matrix to store the likelihood of each observation-group
  mat lk_c = mat(n, g); // matrix to store the likelihood of each observation-group
  mat XX = mat(m, m);
  mat invXX = mat(m, m);
  colvec betahat = colvec(m);
  mat B1 = mat(m, m);
  colvec b1 = colvec(m);
  double v0s0_r = v0_r * s0_r * s0_r;
  double v0s0_c = v0_c * s0_c * s0_c;
  double v1div2, SS, r, v1s1div2, sigma2;
  mat sigma2B1 = mat(m, m);
  colvec eta = w0;
  colvec ngroup = colvec(g);
  cube betasample_r = cube(m, g, iter);
  cube betasample_c = cube(m, g, iter);
  mat sigmasample_r = mat(iter, g);
  mat sigmasample_c = mat(iter, g);
  mat etasample = mat(iter, g);
  cube Ssample = cube(l, g, iter);
  mat LLsample_r = mat(iter, n);
  mat LLsample_c = mat(iter, n);
  colvec Sprop = colvec(g);
  // burn-in period
  for (i = 0; i < burnin; i ++) {
    // sample weight parameters
    for (k = 0; k < g; k ++) {
      ngroup(k) = sum(S.col(k));
    }
    eta = as<colvec>(rdirichlet(as<NumericVector>(wrap(w0 + ngroup))));
    // sample rating coefficients
    for (k = 0; k < g; k ++) {
      nmember = sum(S2.col(k));
      uvec index = find(S2.col(k) == 1);
      colvec groupY = colvec(nmember);
      mat groupX = mat(nmember, m);
      groupY = y_r.rows(index);
      groupX = X.rows(index);
      XX = trans(groupX) * groupX;
      invXX = inv(XX);
      betahat = invXX * trans(groupX) * groupY;
      B1 = inv(inv(B0_r) + XX);
      b1 = B1 * (inv(B0_r) * b0_r + XX * betahat);
      v1div2 = (v0_r + nmember) / 2;
      SS = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
      r = as_scalar(trans(b0_r - betahat) * inv(B0_r + invXX) * (b0_r - betahat));
      v1s1div2 = 2 / (v0s0_r + SS + r);
      sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
      sigma_r(k) = sqrt(sigma2);
      sigma2B1 = sigma2 * B1;
      beta_r.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
    }
    // sample choice coefficients
    for (k = 0; k < g; k ++) {
      nmember = sum(S2.col(k));
      uvec index = find(S2.col(k) == 1);
      colvec groupY = colvec(nmember);
      mat groupX = mat(nmember, m);
      groupY = y_c.rows(index);
      groupX = X.rows(index);
      XX = trans(groupX) * groupX;
      invXX = inv(XX);
      betahat = invXX * trans(groupX) * groupY;
      B1 = inv(inv(B0_c) + XX);
      b1 = B1 * (inv(B0_c) * b0_c + XX * betahat);
      v1div2 = (v0_c + nmember) / 2;
      SS = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
      r = as_scalar(trans(b0_r - betahat) * inv(B0_r + invXX) * (b0_r - betahat));
      v1s1div2 = 2 / (v0s0_c + SS + r);
      sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
      sigma_c(k) = sqrt(sigma2);
      sigma2B1 = sigma2 * B1;
      beta_c.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
    }
    // likelihood of each observation-groups
    for (k = 0; k < g; k ++) {
      muvec = X * beta_r.col(k);
      sdvec.fill(sigma_r(k));
      lk_r.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_r)), 
                               as<NumericVector>(wrap(muvec)), 
                               as<NumericVector>(wrap(sdvec))));
      muvec = X * beta_c.col(k);
      sdvec.fill(sigma_c(k));
      lk_c.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_c)), 
                               as<NumericVector>(wrap(muvec)), 
                               as<NumericVector>(wrap(sdvec))));
    }
    // assign each respondent to latent groups
    for (k = 0; k < l; k ++) { 
      uvec index = find(id == k + 1);
      colvec likelihood = trans(prod(lk_r.rows(index))) % trans(prod(lk_c.rows(index))) % eta;
      S.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(likelihood / sum(likelihood)))));
    }
    for (k = 0; k < n; k ++) {
      S2.row(k) = S.row(id(k) - 1);
    }
    if (i % 10000 == 0) {
      cout << i + 1 << "th iteration has finished \n";
      beta_r.print("Burn-in, current beta (rating):");
      beta_c.print("Burn-in, current beta (choice):");
      sigma_r.print("Burn-in, current sigma (rating):");
      sigma_c.print("Burn-in, current sigma (choice):");
      for (k = 0; k < g; k ++) {
        Sprop(k) = mean(S.col(k));  // group proportions
      }
      Sprop.print("Burn-in, current proportion of the groups:");
    }
  }
  // sampling period
  for (i = 0; i < iter; i ++) {
    for (j = 0; j < thin; j ++) {
      // sample weight parameters
      for (k = 0; k < g; k ++) {
        ngroup(k) = sum(S.col(k));
      }
      eta = as<colvec>(rdirichlet(as<NumericVector>(wrap(w0 + ngroup))));
      // sample rating coefficients
      for (k = 0; k < g; k ++) {
        nmember = sum(S2.col(k));
        uvec index = find(S2.col(k) == 1);
        colvec groupY = colvec(nmember);
        mat groupX = mat(nmember, m);
        groupY = y_r.rows(index);
        groupX = X.rows(index);
        XX = trans(groupX) * groupX;
        invXX = inv(XX);
        betahat = invXX * trans(groupX) * groupY;
        B1 = inv(inv(B0_r) + XX);
        b1 = B1 * (inv(B0_r) * b0_r + XX * betahat);
        v1div2 = (v0_r + nmember) / 2;
        SS = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
        r = as_scalar(trans(b0_r - betahat) * inv(B0_r + invXX) * (b0_r - betahat));
        v1s1div2 = 2 / (v0s0_r + SS + r);
        sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
        sigma_r(k) = sqrt(sigma2);
        sigma2B1 = sigma2 * B1;
        beta_r.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
      }
      // sample choice coefficients
      for (k = 0; k < g; k ++) {
        nmember = sum(S2.col(k));
        uvec index = find(S2.col(k) == 1);
        colvec groupY = colvec(nmember);
        mat groupX = mat(nmember, m);
        groupY = y_c.rows(index);
        groupX = X.rows(index);
        XX = trans(groupX) * groupX;
        invXX = inv(XX);
        betahat = invXX * trans(groupX) * groupY;
        B1 = inv(inv(B0_c) + XX);
        b1 = B1 * (inv(B0_c) * b0_c + XX * betahat);
        v1div2 = (v0_c + nmember) / 2;
        SS = as_scalar(trans(groupY - groupX * betahat) * (groupY - groupX * betahat));
        r = as_scalar(trans(b0_r - betahat) * inv(B0_r + invXX) * (b0_r - betahat));
        v1s1div2 = 2 / (v0s0_c + SS + r);
        sigma2 = 1 / R::rgamma(v1div2, v1s1div2);
        sigma_c(k) = sqrt(sigma2);
        sigma2B1 = sigma2 * B1;
        beta_c.col(k) = as<colvec>(mvrnorm(1, b1, sigma2B1));
      }
      // likelihood of each observation-groups
      for (k = 0; k < g; k ++) {
        muvec = X * beta_r.col(k);
        sdvec.fill(sigma_r(k));
        lk_r.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_r)), 
                                 as<NumericVector>(wrap(muvec)), 
                                 as<NumericVector>(wrap(sdvec))));
        muvec = X * beta_c.col(k);
        sdvec.fill(sigma_c(k));
        lk_c.col(k) = as<colvec>(my_dnorm(Rcpp::as<NumericVector>(wrap(y_c)), 
                                 as<NumericVector>(wrap(muvec)), 
                                 as<NumericVector>(wrap(sdvec))));
      }
      // assign each respondent to latent groups
      for (k = 0; k < l; k ++) {
        uvec index = find(id == k + 1);
        colvec likelihood = trans(prod(lk_r.rows(index))) % trans(prod(lk_c.rows(index))) % eta;
        S.row(k) = as<rowvec>(Rcpp::RcppArmadillo::rmultinom(1, as<NumericVector>(wrap(likelihood / sum(likelihood)))));
      }
      for (k = 0; k < n; k ++) {
        S2.row(k) = S.row(id(k) - 1);
      }
    }
    // record parameters
    betasample_r.slice(i) = beta_r;
    betasample_c.slice(i) = beta_c;
    sigmasample_r.row(i) = trans(sigma_r);
    sigmasample_c.row(i) = trans(sigma_c);
    etasample.row(i) = trans(eta);
    Ssample.slice(i) = S;
    // individual log-likelihoods
    for (k = 0; k < n; k ++) {
      LLsample_r(i, k) = log(accu(lk_r.row(k) % S.row(id(k) - 1)));
      LLsample_c(i, k) = log(accu(lk_c.row(k) % S.row(id(k) - 1)));
    }
    if (i % 10000 == 0) {
      cout << i + 1 << "th iteration has finished \n";
      beta_r.print("Sampling, current beta (rating):");
      beta_c.print("Sampling, current beta (choice):");
      sigma_r.print("Sampling, current sigma (rating):");
      sigma_c.print("Sampling, current sigma (choice):");
      // group proportions
      for (k = 0; k < g; k ++) {
        Sprop(k) = mean(S.col(k));
      }
      Sprop.print("Sampling, current proportion of the groups:");
    }
  }
  mat LLsample = join_rows(LLsample_r, LLsample_c);
  return List::create(Named("beta_r") = wrap(betasample_r), 
                      Named("beta_c") = wrap(betasample_c), 
                      Named("sigma_r") = wrap(sigmasample_r), 
                      Named("sigma_c") = wrap(sigmasample_c), 
                      Named("eta") = wrap(etasample), 
                      Named("S") = wrap(Ssample), 
                      Named("loglik") = wrap(LLsample));
}